#include "Mesh.h"
#include "Element.h"
#include "Point.h"
#include <iostream>
#include <string>
#include <cmath>
#include <fstream>

#include <Eigen/Sparse>
#include <Eigen/IterativeLinearSolvers>
/// Sparse stiffness matrix type.
using SpMat = Eigen::SparseMatrix<double>;
/// (row, column, value) entry used to assemble the sparse matrix.
using Tri = Eigen::Triplet<double>;
/// Dense right-hand-side vector of the linear system.
using RHS = Eigen::VectorXd;
/// Dense vector of nodal solution values, indexed by global dof.
using Solution = Eigen::VectorXd;

/** 
 * Exact (reference) solution of the model problem,
 * u(x, y) = sin(pi * x) * sin(2 * pi * y).
 * It is also evaluated at boundary dofs by the solver to impose
 * Dirichlet boundary values.
 * 
 * @param _p point at which to evaluate; _p[0] is x, _p[1] is y.
 * 
 * @return value of the exact solution at _p.
 */
double u(const Point<2, double> &_p);

/** 
 * Source term of the Poisson problem, f = 5 * pi^2 * u, which is
 * -Laplace(u) for the exact solution u above.
 * 
 * @param _p point at which to evaluate; _p[0] is x, _p[1] is y.
 * 
 * @return value of the source term at _p.
 */
double f(const Point<2, double> &_p);

/**
 * L2 norm of the difference between a reference function and the
 * finite-element solution, integrated element by element with the
 * quadrature rule of the reference element.
 *
 * The prototype carries the quadrature accuracy parameter so that it
 * matches the definition and every call site in this file (the former
 * three-parameter prototype declared an overload that was never defined).
 *
 * @param mesh mesh providing the elements and the global dofs.
 * @param solution nodal values of the discrete solution, indexed by global dof.
 * @param _acc accuracy selector passed to the element's read_quad_info().
 * @param u reference function the discrete solution is compared against.
 *
 * @return the L2 error (also printed to standard output).
 */
template <typename Mesh_Type, typename Element_Type>
double L2_error(Mesh_Type mesh, Solution solution, int _acc, double u(const Point<2, double> &));

/// Write the solution to "<filename>.m" as a Matlab triangulation script.
/// Every mesh triangle is subdivided `rec` times (each level splits a
/// triangle into four) and the solution is sampled at the resulting nodes.
template <typename IrregularMesh_Type, typename Element_Type>
void Irregularmesh2D_Matlab_plot(IrregularMesh_Type &mesh, Solution &solution, const std::string &filename, unsigned int rec);

/// Write the solution to "<filename>.dx" in OpenDX format, using the same
/// `rec`-level subdivision of every mesh triangle as the Matlab variant.
template <typename IrregularMesh_Type, typename Element_Type>
void Irregularmesh2D_OpenDx_plot(IrregularMesh_Type &mesh, Solution &solution, const std::string &filename, unsigned int rec);

/// Assemble and solve the discrete Poisson problem on `mesh` with source
/// term `f`; `_acc` selects the quadrature rule read by the element.
/// The nodal values are stored in `solution`.
template <typename Mesh_Type, typename Element_Type>
void PossionSolver(Mesh_Type &mesh, Solution &solution, int _acc, double f(const Point<2, double> &));

/// Write the solution on a regular mesh to "<filename>.m" as a Matlab
/// script; every cell is sampled on a grid `multiple` times finer.
template <typename RegularMesh_Type, typename Element_Type>
void Regularmesh2D_Matlab_plot(RegularMesh_Type &mesh, Solution &solution, const std::string &filename, unsigned int multiple);

/// Write the solution on a regular mesh to "<filename>.dx" in OpenDX
/// format, sampled on a regular mesh `multiple` times finer.
template <typename RegularMesh_Type, typename Element_Type>
void Regularmesh2D_OpenDx_plot(RegularMesh_Type &mesh, Solution &solution, const std::string &filename, unsigned int multiple);

/**
 * L2 norm of the difference between a reference function and the
 * finite-element solution.
 *
 * For every element the local dofs of the reference element are bound to
 * the element's global dofs, the discrete solution is interpolated at each
 * quadrature point, and the squared difference to `u` is accumulated with
 * the quadrature weight, the Jacobian determinant and the reference volume.
 *
 * @param mesh mesh providing the elements (get_grid) and global dofs (get_dof).
 * @param solution nodal values of the discrete solution, indexed by global dof.
 * @param _acc accuracy selector passed to read_quad_info().
 * @param u reference function; evaluated once per quadrature point.
 *
 * @return square root of the accumulated squared error; the value is also
 *         printed to standard output.
 */
template <typename Mesh_Type, typename Element_Type>
double L2_error(Mesh_Type mesh, Solution solution, int _acc, double u(const Point<2, double> &))
{
    Element_Type ref_ele;
    const size_t n_ele = mesh.get_n_grids();
    const size_t n_dof = ref_ele.get_n_dofs();

    /// Load the quadrature rule of the reference element.
    ref_ele.read_quad_info("data/triangle.tmp_geo", _acc);
    const std::valarray<Point<2, double>> &quad_pnts = ref_ele.get_quad_points();
    const std::valarray<double> &weights = ref_ele.get_quad_weights();
    /// Kept unsigned so the loop index and its bound have the same type.
    const size_t n_quad_pnts = ref_ele.get_n_quad_points();
    const double volume = ref_ele.get_volume();

    double error = 0.0;
    for (size_t ei = 0; ei < n_ele; ei++)
    {
        /// Global dof indexes of the current element.
        const Geometry &the_ele = mesh.get_grid(ei);
        const std::valarray<size_t> &dof = the_ele.get_dof();

        /// Bind the local dofs of the reference element to the global
        /// dofs of this element, which defines the local/global mapping.
        for (size_t k = 0; k < n_dof; k++)
            ref_ele.set_global_dof(k, mesh.get_dof(dof[k]));

        for (size_t l = 0; l < n_quad_pnts; l++)
        {
            const double detJ = ref_ele.global2local_Jacobi_det(quad_pnts[l]);

            /// Discrete solution interpolated at the quadrature point.
            double u_sol = 0;
            for (size_t i = 0; i < n_dof; i++)
            {
                const double phi_i = ref_ele.basis_function(i, quad_pnts[l]);
                u_sol += phi_i * solution[dof[i]];
            }

            /// Evaluate the reference function once instead of twice.
            const double diff = u(ref_ele.local2global(quad_pnts[l])) - u_sol;
            error += diff * diff * weights[l] * detJ * volume;
        }
    }
    error = std::sqrt(error);
    std::cout << "L2 error: " << error << "\n"
              << std::endl;
    return error;
}

/// Recursively split the triangle (local_n1_index, local_n2_index,
/// local_n3_index) of `_nodes` into four sub-triangles, `rec` times.
/// `new_index` holds three running counters: [0] next free local node slot,
/// [1] next global node index, [2] next free row of `connections`.
/// At recursion depth 0 the triangle's global node indexes are written to
/// `connections`.
void refinement(std::valarray<Point<2, double>> &_nodes,
                int local_n1_index, int local_n2_index, int local_n3_index,
                size_t new_index[3], int rec, std::valarray<std::valarray<size_t>> &connections);

int main(int argc, char *argv[])
{
    Solution solution;
    Easymesh mesh1;
    mesh1.readData("data/D");
    // mesh1.readData("test");
    int acc = 2;

    PossionSolver<Easymesh, P1Element<double>>(mesh1, solution, acc, f);
    /// L2 error
    L2_error<Easymesh, P1Element<double>>(mesh1, solution, 4, u);
    /// Output the solution to file in OpenDx format.
    Irregularmesh2D_OpenDx_plot<Easymesh, P1Element<double>>(mesh1, solution, "u11", 0);
    /// Output the solution to file in Matlab format.
    Irregularmesh2D_Matlab_plot<Easymesh, P1Element<double>>(mesh1, solution, "u11", 0);

    mesh1.build_dofs();
    PossionSolver<Easymesh, P2Element<double>>(mesh1, solution, acc, f);
    /// L2 error
    L2_error<Easymesh, P2Element<double>>(mesh1, solution, 6, u);
    /// Output the solution to file in OpenDx format.
    Irregularmesh2D_OpenDx_plot<Easymesh, P2Element<double>>(mesh1, solution, "u21", 2);
    /// Output the solution to file in Matlab format.
    Irregularmesh2D_Matlab_plot<Easymesh, P2Element<double>>(mesh1, solution, "u21", 2);

    int nx = 2, ny = 2;
    RegularMesh mesh2;
    Point<2, double> x0y0({0, 0}), x1y1({1, 1});
    mesh2.set_lbc(x0y0);
    mesh2.set_ruc(x1y1);
    mesh2.set_nx(nx);
    mesh2.set_ny(ny);

    PossionSolver<RegularMesh, P1Element<double>>(mesh2, solution, acc, f);
    /// L2 error
    L2_error<RegularMesh, P1Element<double>>(mesh2, solution, 4, u);
    /// Output the solution to file in OpenDx format.
    Regularmesh2D_OpenDx_plot<RegularMesh, P1Element<double>>(mesh2, solution, "u12", 1);
    /// Output the solution to file in Matlab format.
    Regularmesh2D_Matlab_plot<RegularMesh, P1Element<double>>(mesh2, solution, "u12", 1);

    RegularMeshP2 mesh3(x0y0, x1y1, nx, ny);

    PossionSolver<RegularMeshP2, P2Element<double>>(mesh3, solution, acc, f);
    /// L2 error
    L2_error<RegularMeshP2, P2Element<double>>(mesh3, solution, 6, u);
    /// Output the solution to file in OpenDx format.
    Regularmesh2D_OpenDx_plot<RegularMeshP2, P2Element<double>>(mesh3, solution, "u22", 5);
    /// Output the solution to file in Matlab format.
    Regularmesh2D_Matlab_plot<RegularMeshP2, P2Element<double>>(mesh3, solution, "u22", 5);

    /// Output the solution to file in OpenDx format.
    Irregularmesh2D_OpenDx_plot<RegularMeshP2, P2Element<double>>(mesh3, solution, "u23", 2);
    /// Output the solution to file in Matlab format.
    Irregularmesh2D_Matlab_plot<RegularMeshP2, P2Element<double>>(mesh3, solution, "u23", 2);

    // mesh2 = (RegularMeshP2)mesh2;    //很可惜似乎是不行的

    return 0;
};

/// Exact solution u(x, y) = sin(pi * x) * sin(2 * pi * y), with
/// x = _p[0] and y = _p[1].
double u(const Point<2, double> &_p)
{
    const double sin_x = std::sin(M_PI * _p[0]);
    const double sin_2y = std::sin(2.0 * M_PI * _p[1]);
    return sin_x * sin_2y;
}

/// Source term f = 5 * pi^2 * u evaluated at _p.
double f(const Point<2, double> &_p)
{
    const double exact = u(_p);
    return 5.0 * M_PI * M_PI * exact;
}

/**
 * Recursively subdivide one triangle into four by inserting the midpoints
 * of its three edges, `rec` levels deep, and record the resulting leaf
 * triangles.
 *
 * @param _nodes node storage of the current mesh element; new midpoints are
 *        written at the local slots taken from new_index[0]. The caller
 *        must have sized it for all nodes that will be created.
 * @param local_n1_index local index (into _nodes) of the first vertex.
 * @param local_n2_index local index of the second vertex.
 * @param local_n3_index local index of the third vertex.
 * @param new_index running counters, updated in place:
 *        [0] next free local node slot, [1] next global node index,
 *        [2] next free row of `connections`.
 * @param rec remaining subdivision depth; 0 records the triangle itself.
 * @param connections one row per leaf triangle; each row receives the three
 *        global node indexes of that triangle.
 */
void refinement(std::valarray<Point<2, double>> &_nodes,
                int local_n1_index, int local_n2_index, int local_n3_index,
                size_t new_index[3], int rec, std::valarray<std::valarray<size_t>> &connections)
{
    /// A degenerate triangle is reported but still processed.
    if (local_n1_index == local_n2_index || local_n1_index == local_n3_index || local_n2_index == local_n3_index)
        std::cerr << "two equal node index" << std::endl;

    if (rec > 0)
    {
        /// Reserve three consecutive local slots for the edge midpoints.
        const int new_node1_index = static_cast<int>(new_index[0]++);
        const int new_node2_index = static_cast<int>(new_index[0]++);
        const int new_node3_index = static_cast<int>(new_index[0]++);

        /// Midpoint k lies on the edge opposite to vertex k; each one gets
        /// the next global node index.
        _nodes[new_node1_index] = mid_point<2, double>(_nodes[local_n2_index], _nodes[local_n3_index]);
        _nodes[new_node1_index].set_index(new_index[1]++);
        _nodes[new_node2_index] = mid_point<2, double>(_nodes[local_n1_index], _nodes[local_n3_index]);
        _nodes[new_node2_index].set_index(new_index[1]++);
        _nodes[new_node3_index] = mid_point<2, double>(_nodes[local_n1_index], _nodes[local_n2_index]);
        _nodes[new_node3_index].set_index(new_index[1]++);

        /// Three corner triangles followed by the central one.
        refinement(_nodes, local_n1_index, new_node2_index, new_node3_index, new_index, rec - 1, connections);
        refinement(_nodes, local_n2_index, new_node1_index, new_node3_index, new_index, rec - 1, connections);
        refinement(_nodes, local_n3_index, new_node1_index, new_node2_index, new_index, rec - 1, connections);
        refinement(_nodes, new_node1_index, new_node2_index, new_node3_index, new_index, rec - 1, connections);
    }
    else
    {
        std::valarray<size_t> &triangle = connections[new_index[2]];
        /// The callers in this file create the rows with resize(count),
        /// which leaves every row empty; make room for the three vertex
        /// indexes before writing so the writes stay in bounds.
        if (triangle.size() != 3)
            triangle.resize(3);
        triangle[0] = _nodes[local_n1_index].get_index();
        triangle[1] = _nodes[local_n2_index].get_index();
        triangle[2] = _nodes[local_n3_index].get_index();
        new_index[2]++;
    }
}

template <typename Mesh_Type, typename Element_Type>
void PossionSolver(Mesh_Type &mesh, Solution &solution, int _acc, double f(const Point<2, double> &))
{
    Element_Type ref_ele;

    ref_ele.read_quad_info("data/triangle.tmp_geo", _acc);

    const std::valarray<Point<2, double>> &quad_pnts = ref_ele.get_quad_points();
    const std::valarray<double> &weights = ref_ele.get_quad_weights();
    int n_quad_pnts = ref_ele.get_n_quad_points();
    double volume = ref_ele.get_volume();

    size_t n_ele = mesh.get_n_grids();
    size_t n_dof = ref_ele.get_n_dofs();
    size_t n_total_dofs = mesh.get_n_dofs();

    SpMat K(n_total_dofs, n_total_dofs);
    RHS Rhs(n_total_dofs);
    Rhs.setZero(n_total_dofs);
    std::vector<Tri> TriList(n_ele * n_dof * n_dof * n_quad_pnts);
    std::vector<Tri>::iterator it = TriList.begin();
    std::cout << "its size:" << TriList.size() << std::endl;

    for (size_t ei = 0; ei < n_ele; ei++)
    {
        const Geometry &the_ele = mesh.get_grid(ei);
        const std::valarray<size_t> &dof = the_ele.get_dof();
        // std::cout << "ele_id: " << the_ele.get_index() << std::endl;
        for (size_t k = 0; k < n_dof; k++)
            ref_ele.set_global_dof(k, mesh.get_dof(dof[k]));

        for (size_t l = 0; l < n_quad_pnts; l++)
        {
            double detJ = ref_ele.global2local_Jacobi_det(quad_pnts[l]);
            Matrix<double> invJ = ref_ele.local2global_Jacobi(quad_pnts[l]);

            for (size_t i = 0; i < n_dof; i++)
            {
                std::valarray<double> grad_i = invJ * ref_ele.basis_gradient(i, quad_pnts[l]);
                double phi_i = ref_ele.basis_function(i, quad_pnts[l]);
                for (size_t j = 0; j < n_dof; j++)
                {
                    std::valarray<double> grad_j = invJ * ref_ele.basis_gradient(j, quad_pnts[l]);
                    double cont = (grad_i[0] * grad_j[0] + grad_i[1] * grad_j[1]) * weights[l] * detJ * volume;
                    // A(dof[i], dof[j]) += cont;
                    *it = Tri(dof[i], dof[j], cont);
                    it++;
                }
                /// rhs(i) = \int f(x_i) phi_i dx
                // rhs[dof[i]] += f(ref_ele.local2global(quad_pnts[l])) * phi_i * weights[l] * detJ * volume;

                Rhs[dof[i]] += f(ref_ele.local2global(quad_pnts[l])) * phi_i * weights[l] * detJ * volume;
            }
        }
    }

    K.setFromTriplets(TriList.begin(), TriList.end());
    K.makeCompressed();

    for (size_t i = 0; i < n_total_dofs; i++)
    {
        const Point<2, double> &global_dof = mesh.get_dof(i);
        int bm = global_dof.get_boundary_mark();
        if (bm == 0)
            continue;
        else if (bm == 1)
        {
            double boundary_value = u(global_dof);
            Rhs[i] = boundary_value * K.coeffRef(i, i);
            for (Eigen::SparseMatrix<double>::InnerIterator it(K, i); it; ++it)
            {
                size_t row = it.row();
                if (row == i)
                    continue;
                Rhs[row] -= K.coeffRef(i, row) * boundary_value;
                K.coeffRef(i, row) = 0.0;
                K.coeffRef(row, i) = 0.0;
            }
        }
    };

    Eigen::ConjugateGradient<Eigen::SparseMatrix<double>> Solver_sparse;
    Solver_sparse.setTolerance(1e-15);
    Solver_sparse.compute(K);
    solution = Solver_sparse.solve(Rhs);

    // /// stupid error!
    // double error = 0.0;
    // for (size_t i = 0; i < n_total_dofs; i++)
    //     error += (solution[i] - u(global_dofs[i])) * (solution[i] - u(global_dofs[i]));
    // error = std::sqrt(error);
    // std::cout << "stupid error: " << error << std::endl;
}

/**
 * Write the solution to "<filename>.m" as a Matlab script that builds a
 * triangulation (T, P) and plots it with trisurf.
 *
 * Every mesh triangle is subdivided `rec` times (4^rec sub-triangles,
 * 2 + 4^rec nodes per mesh triangle; nodes are not shared between
 * neighbouring mesh triangles) and the finite-element solution is
 * evaluated at each node.
 *
 * @param mesh mesh providing elements (get_grid) and global dofs (get_dof).
 * @param solution nodal values indexed by global dof.
 * @param filename output path without the ".m" extension.
 * @param rec number of subdivision levels (0 plots the mesh triangles).
 */
template <typename IrregularMesh_Type, typename Element_Type>
void Irregularmesh2D_Matlab_plot(IrregularMesh_Type &mesh, Solution &solution, const std::string &filename, unsigned int rec)
{
    Element_Type ref_ele;
    const size_t n_ele = mesh.get_n_grids();
    const size_t n_dof = ref_ele.get_n_dofs();
    /// Exact integer powers of four; pow() returns a double whose
    /// truncation to an integer is not guaranteed to be exact.
    const size_t ele_per_grid = static_cast<size_t>(1) << (2 * rec);
    const size_t node_per_grid = 2 + ele_per_grid; // 3 + 3 * (1 + 4 + 16 + ...)

    std::valarray<std::valarray<Point<2, double>>> nodes;
    std::valarray<std::valarray<std::valarray<size_t>>> connections;
    std::valarray<std::valarray<double>> values;
    nodes.resize(n_ele);
    connections.resize(n_ele);
    values.resize(n_ele);

    /// Counters shared with refinement():
    /// [0] local node slot, [1] global node index, [2] local triangle row.
    size_t node_index[3] = {0, 0, 0};
    for (size_t ei = 0; ei < n_ele; ei++)
    {
        nodes[ei].resize(node_per_grid);
        /// Every row must hold three vertex indexes before refinement()
        /// writes into it.
        connections[ei].resize(ele_per_grid, std::valarray<size_t>(3));
        values[ei].resize(node_per_grid);

        const Geometry &the_ele = mesh.get_grid(ei);
        const std::valarray<size_t> &dof = the_ele.get_dof();
        node_index[0] = 3; // slots 0..2 hold the triangle's own vertices
        node_index[2] = 0;
        for (size_t k = 0; k < n_dof; k++)
            ref_ele.set_global_dof(k, mesh.get_dof(dof[k]));

        /// The first three dofs are used as the triangle vertices.
        for (size_t k = 0; k < 3; k++)
        {
            nodes[ei][k] = mesh.get_dof(dof[k]);
            nodes[ei][k].set_index(node_index[1]++);
        }
        refinement(nodes[ei], 0, 1, 2, node_index, rec, connections[ei]);

        /// Evaluate the discrete solution at every (sub)node.
        for (size_t k = 0; k < node_per_grid; k++)
        {
            Point<2, double> local_pnt = ref_ele.global2local(nodes[ei][k]);

            double u_sol = 0;
            for (size_t i = 0; i < n_dof; i++)
            {
                const double phi_i = ref_ele.basis_function(i, local_pnt);
                u_sol += phi_i * solution[dof[i]];
            }
            values[ei][k] = u_sol;
        }
    }

    /// Output the solution to file in Matlab format.
    std::ofstream output((filename + ".m").c_str());
    if (!output)
    {
        std::cerr << "cannot open " << filename << ".m for writing" << std::endl;
        return;
    }
    output << "T=zeros(" << n_ele * ele_per_grid << ",3);\n";
    output << "P=zeros(" << n_ele * node_per_grid << ",3);\n";

    /// Node table: x, y, value (Matlab indexes are 1-based).
    for (size_t ei = 0; ei < n_ele; ei++)
        for (size_t k = 0; k < node_per_grid; k++)
            output << "P(" << ei * node_per_grid + k + 1 << ",:)=[" << nodes[ei][k][0] << "\t" << nodes[ei][k][1] << "\t" << values[ei][k] << "];\n";

    /// Triangle table: three 1-based node indexes per row.
    for (size_t ei = 0; ei < n_ele; ei++)
        for (size_t k = 0; k < ele_per_grid; k++)
            output << "T(" << ei * ele_per_grid + k + 1 << ",:)=["
                   << connections[ei][k][0] + 1 << "\t" << connections[ei][k][1] + 1 << "\t" << connections[ei][k][2] + 1 << "];\n";

    output << "TR=triangulation(T,P(:,1),P(:,2),P(:,3));\n";
    output << "trisurf(TR);\n";
    output.close();
}

/**
 * Write the solution to "<filename>.dx" in OpenDX format (positions,
 * triangle connections and nodal data).
 *
 * Every mesh triangle is subdivided `rec` times (4^rec sub-triangles,
 * 2 + 4^rec nodes per mesh triangle; nodes are not shared between
 * neighbouring mesh triangles) and the finite-element solution is
 * evaluated at each node.
 *
 * @param mesh mesh providing elements (get_grid) and global dofs (get_dof).
 * @param solution nodal values indexed by global dof.
 * @param filename output path without the ".dx" extension.
 * @param rec number of subdivision levels; being unsigned it cannot be
 *        negative, so no range check is needed.
 */
template <typename IrregularMesh_Type, typename Element_Type>
void Irregularmesh2D_OpenDx_plot(IrregularMesh_Type &mesh, Solution &solution, const std::string &filename, unsigned int rec)
{
    Element_Type ref_ele;
    const size_t n_ele = mesh.get_n_grids();
    const size_t n_dof = ref_ele.get_n_dofs();
    /// Exact integer powers of four; pow() returns a double whose
    /// truncation to an integer is not guaranteed to be exact.
    const size_t ele_per_grid = static_cast<size_t>(1) << (2 * rec);
    const size_t node_per_grid = 2 + ele_per_grid; // 3 + 3 * (1 + 4 + 16 + ...)

    std::valarray<std::valarray<Point<2, double>>> nodes;
    std::valarray<std::valarray<std::valarray<size_t>>> connections;
    std::valarray<std::valarray<double>> values;
    nodes.resize(n_ele);
    connections.resize(n_ele);
    values.resize(n_ele);

    /// Counters shared with refinement():
    /// [0] local node slot, [1] global node index, [2] local triangle row.
    size_t node_index[3] = {0, 0, 0};
    for (size_t ei = 0; ei < n_ele; ei++)
    {
        nodes[ei].resize(node_per_grid);
        /// Every row must hold three vertex indexes before refinement()
        /// writes into it.
        connections[ei].resize(ele_per_grid, std::valarray<size_t>(3));
        values[ei].resize(node_per_grid);

        const Geometry &the_ele = mesh.get_grid(ei);
        const std::valarray<size_t> &dof = the_ele.get_dof();
        node_index[0] = 3; // slots 0..2 hold the triangle's own vertices
        node_index[2] = 0;
        for (size_t k = 0; k < n_dof; k++)
            ref_ele.set_global_dof(k, mesh.get_dof(dof[k]));

        /// The first three dofs are used as the triangle vertices.
        for (size_t k = 0; k < 3; k++)
        {
            nodes[ei][k] = mesh.get_dof(dof[k]);
            nodes[ei][k].set_index(node_index[1]++);
        }
        refinement(nodes[ei], 0, 1, 2, node_index, rec, connections[ei]);

        /// Evaluate the discrete solution at every (sub)node.
        for (size_t k = 0; k < node_per_grid; k++)
        {
            Point<2, double> local_pnt = ref_ele.global2local(nodes[ei][k]);

            double u_sol = 0;
            for (size_t i = 0; i < n_dof; i++)
            {
                const double phi_i = ref_ele.basis_function(i, local_pnt);
                u_sol += phi_i * solution[dof[i]];
            }
            values[ei][k] = u_sol;
        }
    }

    /// Output the solution to file in OpenDx format.
    std::ofstream output((filename + ".dx").c_str());
    if (!output)
    {
        std::cerr << "cannot open " << filename << ".dx for writing" << std::endl;
        return;
    }

    /// Object 1: node positions.
    output << "object 1 class array type float rank 1 shape 2 item "
           << n_ele * node_per_grid
           << " data follows\n";
    output.setf(std::ios::fixed);
    output.precision(20);
    for (size_t ei = 0; ei < n_ele; ei++)
        for (size_t k = 0; k < node_per_grid; k++)
            output << nodes[ei][k][0] << "\t" << nodes[ei][k][1] << "\n";

    /// Object 2: triangles as 0-based node indexes.
    output << "\n";
    output << "object 2 class array type int rank 1 shape 3 item "
           << n_ele * ele_per_grid << " data follows\n";
    for (size_t ei = 0; ei < n_ele; ei++)
        for (size_t k = 0; k < ele_per_grid; k++)
            output << connections[ei][k][0] << "\t" << connections[ei][k][1] << "\t" << connections[ei][k][2] << "\t"
                   << "\n";
    output << "attribute \"element type\" string \"triangles\"\n";
    output << "attribute \"ref\" string \"positions\"\n";
    output << "\n";

    /// Object 3: solution value at every node.
    output << "object 3 class array type float rank 0 item "
           << n_ele * node_per_grid
           << " data follows\n";
    for (size_t ei = 0; ei < n_ele; ei++)
        for (size_t k = 0; k < node_per_grid; k++)
            output << values[ei][k] << "\n";

    output << "attribute \"dep\" string \"positions\"\n";
    output << "\n";
    output << "object \"FEMFunction-2d\" class field\n";
    output << "component \"positions\" value 1\n";
    output << "component \"connections\" value 2\n";
    output << "component \"data\" value 3\n";
    output << "end\n";

    output.close();
}

/**
 * Write the solution on a regular mesh to "<filename>.m" as a Matlab
 * script that fills a matrix z on a grid `multiple` times finer than the
 * mesh and plots it with trisurf.
 *
 * The code treats elements 2c and 2c+1 as the lower-left and upper-right
 * triangle of cell c (cells numbered row by row). Each triangle is sampled
 * at every fine-grid point of the closed triangle, so points on the right
 * and top edges of the domain are written too; points shared by several
 * triangles are simply written more than once.
 *
 * @param mesh regular mesh (get_nx/get_ny/get_hx/get_hy/get_lbc/get_ruc).
 * @param solution nodal values indexed by global dof.
 * @param filename output path without the ".m" extension.
 * @param multiple refinement factor per cell; 0 is reported and replaced by 1.
 */
template <typename RegularMesh_Type, typename Element_Type>
void Regularmesh2D_Matlab_plot(RegularMesh_Type &mesh, Solution &solution, const std::string &filename, unsigned int multiple)
{
    if (multiple == 0)
    {
        std::cerr << "multiple cannot be 0 !!" << std::endl;
        multiple = 1;
    }
    /// Output the solution to file in Matlab format.
    std::ofstream output((filename + ".m").c_str());
    const size_t n_ele = mesh.get_n_grids();
    const size_t nx = mesh.get_nx(), ny = mesh.get_ny();
    const double hx = mesh.get_hx(), hy = mesh.get_hy();
    /// Lower-left corner, so sample points match the linspace() axes below.
    /// NOTE(review): assumes the mesh dofs are located at lbc + (i*hx, j*hy).
    const double x0 = mesh.get_lbc()[0], y0 = mesh.get_lbc()[1];
    Element_Type ref_ele;
    const size_t n_dof = ref_ele.get_n_dofs();
    output
        << "x=linspace(" << mesh.get_lbc()[0] << "," << mesh.get_ruc()[0] << "," << multiple * nx + 1 << ");\n";
    output << "y=linspace(" << mesh.get_lbc()[1] << "," << mesh.get_ruc()[1] << "," << multiple * ny + 1 << ");\n";
    output << "z=zeros(" << multiple * ny + 1 << "," << multiple * nx + 1 << ");\n";

    for (size_t ei = 0; ei < n_ele; ei++)
    {
        const Geometry &the_ele = mesh.get_grid(ei);
        const std::valarray<size_t> &dof = the_ele.get_dof();
        for (size_t k = 0; k < n_dof; k++)
            ref_ele.set_global_dof(k, mesh.get_dof(dof[k]));

        /// Cell containing this triangle, and which half of the cell it is.
        const size_t cell = ei / 2;
        const size_t ix = cell % nx;
        const size_t iy = cell / nx;
        const bool is_lower = (ei % 2 == 0);

        for (unsigned int iix = 0; iix <= multiple; iix++)
        {
            /// Lower triangle: iix + iiy <= multiple.
            /// Upper triangle: iix + iiy >= multiple.
            const unsigned int iiy_first = is_lower ? 0u : multiple - iix;
            const unsigned int iiy_last = is_lower ? multiple - iix : multiple;
            for (unsigned int iiy = iiy_first; iiy <= iiy_last; iiy++)
            {
                Point<2, double> global_pnt({x0 + (ix + 1. * iix / multiple) * hx, y0 + (iy + 1. * iiy / multiple) * hy});
                Point<2, double> local_pnt = ref_ele.global2local(global_pnt);
                double u_sol = 0;
                for (size_t i = 0; i < n_dof; i++)
                {
                    const double phi_i = ref_ele.basis_function(i, local_pnt);
                    u_sol += phi_i * solution[dof[i]];
                }
                /// Matlab: row index follows y, column index follows x, 1-based.
                output << "z(" << iy * multiple + iiy + 1 << "," << ix * multiple + iix + 1 << ")=" << u_sol << ";\n";
            }
        }
    }
    // output << "surf(x,y,z);" << std::endl;
    // output << "shading interp" << std::endl;
    output << "[x,y]=meshgrid(x,y);\n";
    output << " T = delaunay(x,y);\n";
    output << "trisurf(T,x,y,z);\n";
    output.close();
}

template <typename RegularMesh_Type, typename Element_Type>
void Regularmesh2D_OpenDx_plot(RegularMesh_Type &mesh, Solution &solution, const std::string &filename, unsigned int multiple)
{
    /// Output the solution to file in OpenDx format.
    std::ofstream output((filename + ".dx").c_str());
    size_t nx = mesh.get_nx(), ny = mesh.get_ny();
    Element_Type ref_ele;
    nx *= multiple;
    ny *= multiple;
    RegularMesh meshtmp;
    meshtmp.set_nxny(nx, ny);
    size_t n_node = meshtmp.get_n_points();
    size_t n_ele = meshtmp.get_n_grids();
    int n_dof = ref_ele.get_n_dofs();

    output << "object 1 class array type float rank 1 shape 2 item "
           << n_node
           << " data follows" << std::endl;
    output.setf(std::ios::fixed);

    output.precision(20);
    for (size_t ni = 0; ni < n_node; ni++)
    {
        Point<2, double> pnt = meshtmp.get_point(ni);
        output << pnt[0] << "\t" << pnt[1] << std::endl;
    }

    output << std::endl;
    output << "object 2 class array type int rank 1 shape 3 item " << n_ele << " data follows" << std::endl;

    for (size_t ei = 0; ei < n_ele; ei++)
    {
        const Geometry &the_ele = meshtmp.get_grid(ei);
        output << the_ele.get_vertex(0) << "\t" << the_ele.get_vertex(1) << "\t" << the_ele.get_vertex(2) << std::endl;
    }
    output << "attribute \"element type\" string \"triangles\"" << std::endl;
    output << "attribute \"ref\" string \"positions\"" << std::endl;
    output << std::endl;
    output << "object 3 class array type float rank 0 item "
           << n_node
           << " data follows" << std::endl;

    double hx = meshtmp.get_hx(), hy = meshtmp.get_hy();
    size_t last_ei = 999;
    Geometry the_ele;
    std::valarray<size_t> dof;
    for (size_t ni = 0; ni < n_node; ni++)
    {
        size_t ix = (ni) % (nx + 1); //index of node
        size_t iy = (ni) / (nx + 1);
        Point<2, double> global_pnt({ix * hx, iy * hy});
        int is_upper = (ix % multiple + iy % multiple) > multiple; //wheater node in the upper triangle
        ix /= multiple;                                            //index of element
        iy /= multiple;
        if (ix == nx/multiple) //right boundary,upper boundary
        {
            ix--;
            is_upper = 1;
        }
        if (iy == ny/multiple)
        {
            iy--;
            is_upper = 1;
        }

        size_t ei = (2 * ix) + iy * (2 * nx / multiple) + is_upper;
        if (ei != last_ei)
        {
            the_ele = mesh.get_grid(ei);
            dof = the_ele.get_dof();
            for (size_t k = 0; k < n_dof; k++)
            {
                ref_ele.set_global_dof(k, mesh.get_dof(dof[k]));
            }
            last_ei = ei;
        }
        Point<2, double> local_pnt = ref_ele.global2local(global_pnt);
        double u_sol = 0;
        for (size_t i = 0; i < n_dof; i++)
        {
            double phi_i = ref_ele.basis_function(i, local_pnt);
            u_sol += phi_i * solution[dof[i]];
        }
        output << u_sol << std::endl;
    }

    output << "attribute \"dep\" string \"positions\"" << std::endl;
    output << std::endl;
    output << "object \"FEMFunction-2d\" class field" << std::endl;
    output << "component \"positions\" value 1" << std::endl;
    output << "component \"connections\" value 2" << std::endl;
    output << "component \"data\" value 3" << std::endl;
    output << "end" << std::endl;
    output.close();
}
